package face5wap;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import util.ShowUtils;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;

/**
 * GAN demo on MNIST built with Deeplearning4j / ND4J.
 *
 * <p>Despite the class name (LS_WGAN), the loss actually configured below is MSE with a
 * tanh output — i.e. a least-squares-GAN-style objective. The Wasserstein
 * gradient-penalty interpolation helper (randomWeightedAverage) is defined but is not
 * called from the training loop in this file.
 *
 * <p>Three ComputationGraphs are maintained and kept in sync by hand:
 * <ul>
 *   <li>{@code dis} — trainable discriminator (non-zero learning rate);</li>
 *   <li>{@code gen} — stand-alone generator with learning rate 0, used only for sampling;</li>
 *   <li>{@code gan} — generator layers (trainable) stacked on a frozen copy of the
 *       discriminator layers (learning rate 0).</li>
 * </ul>
 * After each discriminator fit, its weights are copied into {@code gan}'s frozen tail;
 * after each {@code gan} fit, the generator weights are copied back into {@code gen}.
 */
public class LS_WGAN {
    private static final Logger log = LoggerFactory.getLogger(LS_WGAN.class);

    // --- Training / data dimensions. Several constants below are not referenced by
    // gan() and appear to be leftovers from other experiments (e.g. batchSizePred,
    // labelIndex, numClasses*, numLinesToSkip, latent_dim, delimiter, newLine, useGpu).
    private static final int batchSizePerWorker = 200;   // minibatch size per training step
    private static final int batchSizePred = 500;
    private static final int labelIndex = 784;
    private static final int numClasses = 10; // Using Softmax.
    private static final int numClassesDis = 1; // Using Sigmoid.
    private static final int numFeatures = 784;          // 28 * 28 flattened MNIST image
    private static final int numIterations = 10000;      // total training minibatches
    private static final int numGenSamples = 10; // This will be a grid so effectively we get {numGenSamples * numGenSamples} samples.
    private static final int numLinesToSkip = 0;
    private static final int numberOfTheBeast = 666;     // RNG seed for all three graphs
    private static final int printEvery = 10;            // visualize generated samples every N batches
    private static final int saveEvery = 100;
    private static final int tensorDimOneSize = 28;      // image height
    private static final int tensorDimTwoSize = 28;      // image width
    private static final int tensorDimThreeSize = 1;     // channels (grayscale)
    private static final int zSize = 2;                  // latent-noise dimensionality actually used
    private static final int latent_dim = 100;           // unused here; zSize is the effective latent size

    private static final double dis_learning_rate = 0.002;
    //private static final double dis_learning_rate = 0.01;
    private static final double frozen_learning_rate = 0.0;  // rate 0.0 == layer effectively frozen
    private static final double gen_learning_rate = 0.004;

    private static final String delimiter = ",";
    private static final String resPath = "F:/face/model/";
    private static final String newLine = "\n";
    private static final String dataSetName = "mnist";

    private static final boolean useGpu = false;

    // Shared Adam updater (beta1=0.5, as commonly used for GAN training).
    private static final IUpdater UPDATER = Adam.builder().learningRate(dis_learning_rate).beta1(0.5).build();

    /**
     * Entry point; delegates straight to {@link #gan()}.
     *
     * @param args ignored
     * @throws Exception propagated from training
     */
    public static void main(String[] args) throws Exception {
        //new dl4jGANComputerVision().GAN(args);
        gan();
    }

    /**
     * Builds the three graphs and runs the adversarial training loop.
     *
     * <p>Per iteration:
     * <ol>
     *   <li>fit the unfrozen discriminator {@code dis} on a merged batch of real MNIST
     *       images (label 1) and generator output (label 0);</li>
     *   <li>copy the freshly trained discriminator weights into {@code gan}'s frozen
     *       discriminator tail;</li>
     *   <li>fit {@code gan} on noise labelled as "real" — only the generator layers have
     *       a non-zero learning rate, so only they update;</li>
     *   <li>copy the updated generator parameters back into the stand-alone {@code gen}.</li>
     * </ol>
     * Every {@code printEvery} batches a numGenSamples x numGenSamples grid of generated
     * samples is rendered via {@code ShowUtils.visualize}.
     *
     * @throws Exception propagated from dataset iteration / DL4J internals
     */
    private static void  gan() throws Exception {
        // Print which backend training runs on (CPU or GPU).
        System.out.println(Nd4j.getBackend());
        Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
        log.info("Unfrozen discriminator!");
        // The discriminator network judges whether data is real or fake:
        // 1. it receives both generator output and real MNIST data;
        // 2. its learning rate is non-zero (trainable).
        ComputationGraph dis = new ComputationGraph(new NeuralNetConfiguration.Builder()
                // always use workspaces (disabled experiment kept for reference)
                /*.trainingWorkspaceMode(WorkspaceMode.ENABLED)
                // inference workspace
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)*/
                .seed(numberOfTheBeast)
                // optimization algorithm - stochastic gradient descent
               // .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                // gradient normalization strategy
               // .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                // clip to [-1, 1]
               // .gradientNormalizationThreshold(1.0)
                //.l2(0.0001)
                // activation function
                //.activation(Activation.TANH)
                // weight initialization
                .weightInit(WeightInit.XAVIER)
                .graphBuilder()
                .addInputs("input2")
                .setInputTypes(InputType.convolutionalFlat(tensorDimOneSize, tensorDimTwoSize, tensorDimThreeSize))
               /* .inputPreProcessor("dis_dense_layer_1", new FeedForwardToCnnPreProcessor(tensorDimOneSize, tensorDimTwoSize , tensorDimThreeSize))*/
                .inputPreProcessor("dis_dense_layer_1", new ReshapePreprocessor(new long[]{1,tensorDimOneSize * tensorDimTwoSize * tensorDimThreeSize},new long[]{1,tensorDimOneSize * tensorDimTwoSize * tensorDimThreeSize},true ))
                .addLayer("dis_dense_layer_1", new DenseLayer.Builder()
                        .updater(Adam.builder().learningRate(dis_learning_rate).beta1(0.5).build())
                        .nOut(512).activation(new ActivationLReLU(0.2)).build(), "input2")
                .addLayer("dis_dense_layer_2", new DenseLayer.Builder()
                        .updater(Adam.builder().learningRate(dis_learning_rate).beta1(0.5).build())
                        .nOut(256).activation(new ActivationLReLU(0.2)).build(), "dis_dense_layer_1")
                // MSE + tanh output: the least-squares GAN objective (single real/fake score).
                .addLayer("dis_dense_layer_3", new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .updater(Adam.builder().learningRate(dis_learning_rate).beta1(0.5).build())
                        .nOut(1)
                        .activation(Activation.TANH)
                        .build(), "dis_dense_layer_2")
                .setOutputs("dis_dense_layer_3")
                .build());
        dis.init();
        System.out.println(dis.summary());
        // Sanity check: print the output shape for a random batch.
        System.out.println(Arrays.toString(dis.output(Nd4j.randn(numGenSamples, numFeatures))[0].shape()));

        log.info("Frozen generator!");
        // The generator network produces fake data:
        // 1. it maps noise vectors to realistic-looking samples;
        // 2. its learning rate here is frozen (0.0) — it is trained via the combined GAN below
        //    and only receives parameter copies.
        ComputationGraph gen = new ComputationGraph(new NeuralNetConfiguration.Builder()
               /* .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)*/
                .seed(numberOfTheBeast)
             /*   .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0)*/
                //.l2(0.0001)
                .activation(new ActivationLReLU(0.2))
                .weightInit(WeightInit.XAVIER)
                .graphBuilder()
                .addInputs("input1")
                .setInputTypes(InputType.feedForward(zSize))
                .addLayer("gen_dense_layer_1", new DenseLayer.Builder()
                        .updater(Adam.builder().learningRate(frozen_learning_rate).beta1(0.5).build())
                        .nOut(256).weightInit(WeightInit.XAVIER).activation(new ActivationLReLU(0.2)).build(), "input1")
                .addLayer("gen_batch_1", new BatchNormalization.Builder()
                        .updater(Adam.builder().learningRate(frozen_learning_rate).beta1(0.5).build()).build(), "gen_dense_layer_1")
                .addLayer("gen_dense_layer_2", new DenseLayer.Builder()
                        .updater(Adam.builder().learningRate(frozen_learning_rate).beta1(0.5).build()).nOut(512).activation(new ActivationLReLU(0.2)).build(), "gen_batch_1")
                .addLayer("gen_batch_2", new BatchNormalization.Builder()
                        .updater(Adam.builder().learningRate(frozen_learning_rate).beta1(0.5).build()).build(), "gen_dense_layer_2")
                .addLayer("gen_dense_layer_3", new DenseLayer.Builder()
                        .updater(Adam.builder().learningRate(frozen_learning_rate).beta1(0.5).build()).nOut(1024).activation(new ActivationLReLU(0.2)).build(), "gen_batch_2")
                .addLayer("gen_batch_3", new BatchNormalization.Builder()
                        .updater(Adam.builder().learningRate(frozen_learning_rate).beta1(0.5).build()).build(), "gen_dense_layer_3")
                // tanh output layer mapping to the flattened 28*28 image space.
                .addLayer("gen_dense_layer_4", new DenseLayer.Builder()
                        .updater(Adam.builder().learningRate(frozen_learning_rate).beta1(0.5).build()).nOut(tensorDimOneSize * tensorDimTwoSize * tensorDimThreeSize).activation(Activation.TANH).build(), "gen_batch_3")
                // NOTE(review): despite the "conv2d" name this is a plain OutputLayer; its
                // activations are what the training loop samples as fake images.
                .addLayer("gen_conv2d_9", new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .nOut(tensorDimOneSize * tensorDimTwoSize * tensorDimThreeSize)
                        .updater(Adam.builder().learningRate(frozen_learning_rate).beta1(0.5).build()).build(),"gen_dense_layer_4")
                .setOutputs("gen_conv2d_9")
                .build());
        gen.init();
        System.out.println(gen.summary());
        System.out.println(Arrays.toString(gen.output(Nd4j.randn(numGenSamples, zSize))[0].shape()));

        log.info("GAN with unfrozen generator and frozen discriminator!");
        // The combined GAN (generator + frozen discriminator) learns to fool the discriminator:
        // 1. it maps noise to samples;
        // 2. generator layers use a non-zero learning rate, discriminator layers a rate of 0.
        // Layer names/topology mirror gen (gan_dense_layer_* / gan_conv2d_9) and
        // dis (gan_dis_dense_layer_*) so parameters can be copied by name below.
        ComputationGraph gan = new ComputationGraph(new NeuralNetConfiguration.Builder()
               /* .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)*/
                .seed(numberOfTheBeast)
             /*   .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)*/
               // .gradientNormalizationThreshold(1.0)
               // .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .updater(UPDATER)
                //.l2(0.0001)
                .graphBuilder()
                .addInputs("gan_input_layer_0")
                .setInputTypes(InputType.feedForward(zSize))
                .addLayer("gan_dense_layer_1", new DenseLayer.Builder()
                        .updater(Adam.builder().learningRate(dis_learning_rate).beta1(0.5).build()).nOut(256).weightInit(WeightInit.XAVIER).activation(new ActivationLReLU(0.2)).build(), "gan_input_layer_0")
                .addLayer("gan_batch_1", new BatchNormalization.Builder()
                        .updater(Adam.builder().learningRate(dis_learning_rate).beta1(0.5).build()).build(), "gan_dense_layer_1")
                .addLayer("gan_dense_layer_2", new DenseLayer.Builder()
                        .updater(Adam.builder().learningRate(dis_learning_rate).beta1(0.5).build()).nOut(512).activation(new ActivationLReLU(0.2)).build(), "gan_batch_1")
                .addLayer("gan_batch_2", new BatchNormalization.Builder()
                        .updater(Adam.builder().learningRate(dis_learning_rate).beta1(0.5).build()).build(), "gan_dense_layer_2")
                .addLayer("gan_dense_layer_3", new DenseLayer.Builder()
                        .updater(Adam.builder().learningRate(dis_learning_rate).beta1(0.5).build()).nOut(1024).activation(new ActivationLReLU(0.2)).build(), "gan_batch_2")
                .addLayer("gan_batch_3", new BatchNormalization.Builder()
                        .updater(Adam.builder().learningRate(dis_learning_rate).beta1(0.5).build()).build(), "gan_dense_layer_3")
                .addLayer("gan_dense_layer_4", new DenseLayer.Builder()
                        .updater(Adam.builder().learningRate(dis_learning_rate).beta1(0.5).build())
                        .nOut(tensorDimOneSize * tensorDimTwoSize * tensorDimThreeSize).activation(Activation.TANH).build(), "gan_batch_3")
                // NOTE(review): in gen this layer is an OutputLayer; here it is a DenseLayer so
                // the discriminator tail can be stacked on top. Only W/b are copied across.
                .addLayer("gan_conv2d_9", new DenseLayer.Builder()
                        .updater(Adam.builder().learningRate(dis_learning_rate).beta1(0.5).build())
                        .nOut(tensorDimOneSize * tensorDimTwoSize * tensorDimThreeSize).build(),"gan_dense_layer_4")
               /* .inputPreProcessor("gan_dis_dense_layer_1", new FeedForwardToCnnPreProcessor(tensorDimOneSize, tensorDimTwoSize , tensorDimThreeSize))*/
                .inputPreProcessor("gan_dis_dense_layer_1", new ReshapePreprocessor(new long[]{1,tensorDimOneSize * tensorDimTwoSize * tensorDimThreeSize},new long[]{1,tensorDimOneSize * tensorDimTwoSize * tensorDimThreeSize},true ))
                // Frozen discriminator tail (learning rate 0.0): receives weight copies from dis.
                .addLayer("gan_dis_dense_layer_1", new DenseLayer.Builder()
                        .updater(Adam.builder().learningRate(frozen_learning_rate).beta1(0.5).build())
                        .nOut(512).activation(new ActivationLReLU(0.2)).build(), "gan_conv2d_9")
                .addLayer("gan_dis_dense_layer_2", new DenseLayer.Builder()
                        .updater(Adam.builder().learningRate(frozen_learning_rate).beta1(0.5).build())
                                .nOut(256).activation(new ActivationLReLU(0.2)).build(), "gan_dis_dense_layer_1")
                .addLayer("gan_dis_dense_layer_3", new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .updater(Adam.builder().learningRate(frozen_learning_rate).beta1(0.5).build())
                        .nOut(1)
                        .activation(Activation.TANH)
                        .build(), "gan_dis_dense_layer_2")
                .setOutputs("gan_dis_dense_layer_3")
                .build());
        gan.init();
        System.out.println(gan.summary());
        System.out.println(Arrays.toString(gan.output(Nd4j.randn(numGenSamples, zSize))[0].shape()));

        //gen.setListeners(new PerformanceListener(10, true));
        dis.setListeners(new PerformanceListener(10, true));
        gan.setListeners(new PerformanceListener(10, true));


        // Fixed grid of 2-D latent points in [-1, 1] x [-1, 1] used for periodic
        // visualization: numGenSamples * numGenSamples samples.
        INDArray grid = Nd4j.linspace(-1l, 1l, numGenSamples);
        Collection<INDArray> z = new ArrayList<>();
       // log.info("Creating some noise!");
        for (int i = 0; i < numGenSamples; i++) {
            for (int j = 0; j < numGenSamples; j++) {
                z.add(Nd4j.create(new double[]{grid.getDouble(0, i), grid.getDouble(0, j)}));
            }
        }

        int batch_counter = 0;
        DataSet trDataSet;

        DataSetIterator iterTrain = new MnistDataSetIterator(batchSizePerWorker, true, 12345);

        for(int it=0;it<numIterations;it++){
            trDataSet = iterTrain.next();
            INDArray real = trDataSet.getFeatures();
            /*long[] shape = real.shape();
            for(int i=0;i<shape[0];i++){
                real.getRow(i).sub(0.5);
            }*/
            // Real samples get label 1, fakes get label 0 (no label smoothing here).
            INDArray realLabel = Nd4j.ones(batchSizePerWorker, 1);//.mul(-1);
           /* visualize(new INDArray[]{real});*/
            // NOTE(review): Nd4j.rand draws from [0,1); with a tanh generator, noise in
            // [-1,1] (e.g. rand*2-1 as in the commented block below) may be intended — confirm.
            INDArray noise = Nd4j.rand(batchSizePerWorker, zSize);

            // Take the generator's "gen_conv2d_9" activations as the fake image batch.
            Map<String, INDArray> map1 =  gen.feedForward(noise,false);
            INDArray fakeImg = map1.get("gen_conv2d_9");



            INDArray fakeLabel =  Nd4j.zeros(batchSizePerWorker, 1);
            // Unfrozen discriminator is trying to figure itself out given a frozen generator.


            DataSet realSet = new DataSet(real, realLabel);
            DataSet fakeSet = new DataSet(fakeImg,fakeLabel);

            // Merge real + fake into a single 2*batchSizePerWorker minibatch.
            DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
           // for(int m=0;m<5;m++){
           // log.info("Training discriminator!");
            dis.fit(data);
           // }

            /*INDArray imgDate =  Nd4j.vstack(trDataSet.getFeatures(),gen.output(Nd4j.rand(batchSizePerWorker, zSize).muli(2.0).subi(1.0))[0]);
            INDArray imgLabel =  Nd4j.vstack(Nd4j.ones(batchSizePerWorker, 1).addi(soften_labels_real),Nd4j.zeros(batchSizePerWorker, 1).addi(soften_labels_fake));
            MultiDataSet disSet = new MultiDataSet( new INDArray[] {imgDate},new INDArray[] { imgLabel});
            dis.fit(disSet); */

            // Update GAN's frozen discriminator with unfrozen discriminator.
            gan.getLayer("gan_dis_dense_layer_1").setParam("W", dis.getLayer("dis_dense_layer_1").getParam("W"));
            gan.getLayer("gan_dis_dense_layer_1").setParam("b", dis.getLayer("dis_dense_layer_1").getParam("b"));
            gan.getLayer("gan_dis_dense_layer_2").setParam("W", dis.getLayer("dis_dense_layer_2").getParam("W"));
            gan.getLayer("gan_dis_dense_layer_2").setParam("b", dis.getLayer("dis_dense_layer_2").getParam("b"));
            gan.getLayer("gan_dis_dense_layer_3").setParam("W", dis.getLayer("dis_dense_layer_3").getParam("W"));
            gan.getLayer("gan_dis_dense_layer_3").setParam("b", dis.getLayer("dis_dense_layer_3").getParam("b"));



            // Tell the frozen discriminator that all the fake examples are real examples.
            // [Fake, Real].
            // Unfrozen generator is trying to fool the frozen discriminator.
            // log.info("Training generator!");
            DataSet dataSetD5 = new DataSet(noise, realLabel);
            gan.fit(dataSetD5);

            // Update frozen generator with GAN's unfrozen generator.
            // Dense layers: copy weights and biases by name.
            gen.getLayer("gen_dense_layer_1").setParam("W", gan.getLayer("gan_dense_layer_1").getParam("W"));
            gen.getLayer("gen_dense_layer_1").setParam("b", gan.getLayer("gan_dense_layer_1").getParam("b"));

            // BatchNormalization layers: copy learned scale/shift (gamma/beta) and the
            // running statistics ("mean"/"log10stdev") so inference-mode feedForward matches.
            gen.getLayer("gen_batch_1").setParam("gamma", gan.getLayer("gan_batch_1").getParam("gamma"));
            gen.getLayer("gen_batch_1").setParam("beta", gan.getLayer("gan_batch_1").getParam("beta"));
            gen.getLayer("gen_batch_1").setParam("mean", gan.getLayer("gan_batch_1").getParam("mean"));
            gen.getLayer("gen_batch_1").setParam("log10stdev", gan.getLayer("gan_batch_1").getParam("log10stdev"));


            gen.getLayer("gen_dense_layer_2").setParam("W", gan.getLayer("gan_dense_layer_2").getParam("W"));
            gen.getLayer("gen_dense_layer_2").setParam("b", gan.getLayer("gan_dense_layer_2").getParam("b"));

            gen.getLayer("gen_batch_2").setParam("gamma", gan.getLayer("gan_batch_2").getParam("gamma"));
            gen.getLayer("gen_batch_2").setParam("beta", gan.getLayer("gan_batch_2").getParam("beta"));
            gen.getLayer("gen_batch_2").setParam("mean", gan.getLayer("gan_batch_2").getParam("mean"));
            gen.getLayer("gen_batch_2").setParam("log10stdev", gan.getLayer("gan_batch_2").getParam("log10stdev"));



            gen.getLayer("gen_dense_layer_3").setParam("W", gan.getLayer("gan_dense_layer_3").getParam("W"));
            gen.getLayer("gen_dense_layer_3").setParam("b", gan.getLayer("gan_dense_layer_3").getParam("b"));

            gen.getLayer("gen_batch_3").setParam("gamma", gan.getLayer("gan_batch_3").getParam("gamma"));
            gen.getLayer("gen_batch_3").setParam("beta", gan.getLayer("gan_batch_3").getParam("beta"));
            gen.getLayer("gen_batch_3").setParam("mean", gan.getLayer("gan_batch_3").getParam("mean"));
            gen.getLayer("gen_batch_3").setParam("log10stdev", gan.getLayer("gan_batch_3").getParam("log10stdev"));


            gen.getLayer("gen_dense_layer_4").setParam("W", gan.getLayer("gan_dense_layer_4").getParam("W"));
            gen.getLayer("gen_dense_layer_4").setParam("b", gan.getLayer("gan_dense_layer_4").getParam("b"));

            gen.getLayer("gen_conv2d_9").setParam("W", gan.getLayer("gan_conv2d_9").getParam("W"));
            gen.getLayer("gen_conv2d_9").setParam("b", gan.getLayer("gan_conv2d_9").getParam("b"));
            log.info("Completed Batch {}!", batch_counter++);

            // Periodically render the fixed latent grid through the (freshly synced) generator.
            if ((batch_counter % printEvery) == 0) {

               // out = gen.output(Nd4j.vstack(z))[0].reshape(numGenSamples * numGenSamples, numFeatures);


                /*      ShowUtils*/
               // INDArray[] samps =  gen.output(Nd4j.vstack(z));

                Map<String, INDArray> map2 =  gen.feedForward(Nd4j.vstack(z),false);
                INDArray samps = map2.get("gen_conv2d_9");
                INDArray[] sampless = new INDArray[1];
                sampless[0] = samps;
                ShowUtils.visualize(sampless,"ls_gan");

            }

            // Restart the MNIST iterator once exhausted so training can run numIterations batches.
            if (!iterTrain.hasNext()) {
                iterTrain.reset();
            }
        }

        //log.info("Saving models!");
      /*  ModelSerializer.writeModel(dis, new File(resPath + dataSetName + "_dis_model.zip"), true);
        ModelSerializer.writeModel(gan, new File(resPath + dataSetName + "_gan_model.zip"), true);
        ModelSerializer.writeModel(gen, new File(resPath + dataSetName + "_gen_model.zip"), true);*/
        /*ModelSerializer.writeModel(sparkCV.getNetwork(), new File(resPath + dataSetName + "_CV_model.zip"), true);

        tm.deleteTempFiles(sc);*/
    }

    /**
     * WGAN-GP style interpolation: returns {@code alpha * real + (1 - alpha) * fake}
     * with a fresh {@code alpha ~ U(0,1)} per row.
     *
     * <p>NOTE(review): defined but not invoked anywhere in this file.
     *
     * @param batch number of rows to draw interpolation weights for
     * @param real  batch of real samples
     * @param fake  batch of generated samples (same shape as {@code real} — TODO confirm)
     * @return element-wise random convex combination of the two batches
     */
    public static INDArray randomWeightedAverage(int batch,INDArray real,INDArray fake){
       // INDArray alpha = Nd4j.rand(new UniformDistribution(0,1),new long[]{batch, 1, 1, 1});// new NDRandom().uniform(32, 1, DataType.FLOAT, new long[]{32, 1, 1, 1});
        INDArray alpha = Nd4j.rand(new UniformDistribution(0,1),new long[]{batch, 1});// new NDRandom().uniform(32, 1, DataType.FLOAT, new long[]{32, 1, 1, 1});
        return (alpha.mul(real)).add((Nd4j.ones(alpha.shape()).sub(alpha)).mul(fake));
    }

    // Lazily created Swing window/panel shared by visualize().
    private static JFrame frame;
    private static JPanel panel;

    /**
     * Renders up to 8 rows of each sample batch as 28x28 grayscale tiles in a shared
     * Swing window (grid layout is 8x4 regardless of how many tiles are added).
     *
     * <p>NOTE(review): currently unused — the training loop calls ShowUtils.visualize instead.
     *
     * @param samples arrays whose rows are flattened 784-pixel images
     */
    private static void visualize(INDArray[] samples) {
        if (frame == null) {
            frame = new JFrame();
            frame.setTitle("Viz");
            frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
            frame.setLayout(new BorderLayout());

            panel = new JPanel();

            panel.setLayout(new GridLayout(8, 4, 8, 8));
            frame.add(panel, BorderLayout.CENTER);
            frame.setVisible(true);
        }

        panel.removeAll();
        for (INDArray sample : samples) {
            if(sample == null || sample.size(0) == 0){
                continue;
            }
            long size = sample.size(0);
            if(size > 0 ){
                for(int i=0;i<size;i++){
                    // Cap at 8 tiles per batch to fit the fixed grid.
                    if (i==8) {
                        break;
                    }
                    panel.add(getImage(sample.getRow(i)));
                }
            }
        }

        frame.revalidate();
        frame.pack();
    }

    /**
     * Converts a flattened 784-element row into a scaled 252x252 grayscale JLabel.
     *
     * <p>NOTE(review): pixel values are multiplied by 255 with no clamping — this assumes
     * inputs in [0,1]; tanh-generated values in [-1,1] would produce out-of-range samples
     * for the byte raster. Confirm the expected input range.
     *
     * @param tensor flattened 28*28 image, row-major
     * @return label holding the rendered, 9x-scaled image
     */
    private static JLabel getImage(INDArray tensor) {
        BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
        //long[] shape = tensor.shape();
        for (int i = 0; i < 784; i++) {
            //System.out.println((255 * (tensor.getDouble(i) + 0.5)));
            bi.getRaster().setSample(i % 28, i / 28, 0, (int) (255 * (tensor.getDouble(i))));
        }
        ImageIcon orig = new ImageIcon(bi);
        Image imageScaled = orig.getImage().getScaledInstance((int) (9 * 28), (int) (9 * 28),
                Image.SCALE_DEFAULT);
        ImageIcon scaled = new ImageIcon(imageScaled);

        return new JLabel(scaled);
    }
    /**
     * Converts a 4-D array (indexed as [batch, channel, height, width] — only element
     * [0][0] is read) into a scaled grayscale JLabel, mapping values from [-1,1] to
     * [0,255] with clamping.
     *
     * <p>NOTE(review): unused in this file; unlike getImage, this variant does clamp.
     *
     * @param array 4-D array; height taken from shape[2], width from shape[3]
     * @return label holding the rendered, 9x28-pixel-scaled image
     */
    private static JLabel imageFromINDArray(INDArray array) {
       // array = array.reshape(28, 28);
        long[] shape = array.shape();
        int height = (int)shape[2];
        int width = (int)shape[3];
        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                //System.out.println(array.getDouble(0, 0, y, x));
                int gray = (int) ((array.getDouble(0, 0, y, x)  + 1) * 127.5);

                // handle out of bounds pixel values
                gray = Math.min(gray, 255);
                gray = Math.max(gray, 0);

                image.getRaster().setSample(x, y, 0, gray);
            }
        }
        ImageIcon orig = new ImageIcon(image);
        Image imageScaled = orig.getImage().getScaledInstance((int) (9 * 28), (int) (9 * 28),
                Image.SCALE_DEFAULT);
        ImageIcon scaled = new ImageIcon(imageScaled);

        return new JLabel(scaled);
        //return image;
    }
}